
Modern Bert Support #15641

Merged
CISC merged 107 commits into ggml-org:master from ryan-mangeno:modern-bert-support on Dec 22, 2025
Conversation

@ryan-mangeno (Contributor) commented Aug 28, 2025

Adding support to run granite-embedding-small, which primarily uses the ModernBERT architecture: https://huggingface.co/ibm-granite/granite-embedding-small-english-r2. Still working on it; I haven't figured out the pre-tokenizer type, or whether I need to implement one. Also, the ubatch-size assert in llama-graph.cpp fails. I hacked it to accept a ubatch size of 1 for testing, but it keeps failing there and I'm not sure why.

If I comment out this line in llama-graph.cpp:

assert(!ubatch.equal_seqs());

then it works.

@ryan-mangeno ryan-mangeno marked this pull request as draft August 28, 2025 17:05
@ryan-mangeno (Contributor, Author) commented Aug 28, 2025

@gabe-l-hart thanks in advance :)

@ryan-mangeno (Contributor, Author) commented:

> @gabe-l-hart thanks in advance :)

Also realizing this a little late, haha: should I change all of the ModernBERT stuff to a Granite Embedding macro like LLM_ARCH_GRANITE_EMBD, or keep it as is?

@CISC (Member) commented Aug 28, 2025

You may want to check out an earlier attempt at ModernBert in #14014

@gabe-l-hart (Collaborator) commented:

Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

@gabe-l-hart (Collaborator) commented:

> also realizing this a little late haha, but should I be changing all of the modern bert stuff to a granite embedding macro like LLM_ARCH_GRANITE_EMBD or keep it as is

In general, we want to keep things as generic as possible. Since this uses the ModernBertModel architecture from transformers, it's best to keep the implementation here similarly generic, unless there's a concrete reason to subset the transformers architecture to just work for Granite (e.g. some non-trivial code path in the transformers version that would make sense as a separate architecture).

@github-actions github-actions bot added the python python script changes label Aug 28, 2025
@ryan-mangeno (Contributor, Author) commented:

> Thanks for getting this together @ryan-mangeno and thanks for pointing out the previous work @CISC. Ryan, let me know if/when you've looked over that PR and found anything to fix and I'll take a pass at review.

will do

@ryan-mangeno (Contributor, Author) commented Sep 3, 2025

@gabe-l-hart I'm looking into ModernBERT's research paper. I can't find a mention of symmetric sliding-window attention, only local sliding-window attention, so I'm going to use LLAMA_SWA_TYPE_LOCAL rather than the LLAMA_SWA_TYPE_SYMMETRIC used in the previous attempt. It also uses global attention every third layer, so I'm going to implement that and then it should be ready for a review :)
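The layer pattern described here can be sketched as follows. This is an illustrative Python sketch, not llama.cpp code; the 128-token window size and the convention that layer 0 is global (layer index % 3 == 0, matching the HF ModernBERT implementation) are assumptions.

```python
# Illustrative layer layout: global (dense) attention every third layer,
# local sliding-window attention on the remaining layers.
GLOBAL_EVERY_N = 3   # global attention every third layer (assumption: layer 0 is global)
SWA_WINDOW = 128     # local sliding-window size in tokens (from the ModernBERT paper)

def is_global_layer(il: int) -> bool:
    return il % GLOBAL_EVERY_N == 0

def attn_window(il: int):
    # None means unrestricted (global) attention for this layer
    return None if is_global_layer(il) else SWA_WINDOW

print([attn_window(il) for il in range(7)])
# [None, 128, 128, None, 128, 128, None]
```

With this convention, layers 0, 3, 6, ... are the global-attention layers and everything else is windowed.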

@gabe-l-hart (Collaborator) commented:

@ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

… per previous attempt, added local sliding window attention that alternates every third layer
@ryan-mangeno (Contributor, Author) commented:

> @ryan-mangeno That sounds good! I haven't unpacked any of those mechanics myself, but can try to get into it if you get stuck.

ok 👍, made some changes, but I'm not sure it's fully ready yet. I'll ping you when I think it's ready, if that's ok.

@ryan-mangeno (Contributor, Author) commented Sep 4, 2025

Status update: I found out that ModernBERT uses an alternating RoPE method, per https://arxiv.org/pdf/2412.13663:

> In ModernBERT, every third layer employs global attention with a RoPE theta of 160,000 and the remaining layers use a 128 token, local sliding window attention with a RoPE theta of 10,000.

I am currently figuring out how to implement this
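A minimal sketch of that alternating scheme, using the values from the quoted paper passage. The names are illustrative only; in llama.cpp these bases ended up stored as `rope_freq_base_train` and `rope_freq_base_train_swa` per a later commit in this PR, but the function below is not the actual implementation.

```python
# Alternating RoPE bases per the ModernBERT paper: global layers (every
# third, assuming layer 0 is global) use theta = 160,000; local
# sliding-window layers use theta = 10,000.
GLOBAL_EVERY_N = 3
ROPE_THETA_GLOBAL = 160_000.0
ROPE_THETA_LOCAL = 10_000.0

def rope_theta(il: int) -> float:
    return ROPE_THETA_GLOBAL if il % GLOBAL_EVERY_N == 0 else ROPE_THETA_LOCAL

print([rope_theta(il) for il in range(4)])
# [160000.0, 10000.0, 10000.0, 160000.0]
```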

@CISC (Member) left a comment:

Just nits left, either remove the cls.predictions mappings in modify_tensors or move them to tensor_mapping, otherwise we can merge when resolved and model is still verified working.

ryan-mangeno and others added 2 commits December 22, 2025 15:51
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
@ryan-mangeno (Contributor, Author) commented Dec 22, 2025

> Just nits left, either remove the cls.predictions mappings in modify_tensors or move them to tensor_mapping, otherwise we can merge when resolved and model is still verified working.

Awesome!! Thanks all for the constant support on this PR!

@ryan-mangeno (Contributor, Author) commented:

> Just nits left, either remove the cls.predictions mappings in modify_tensors or move them to tensor_mapping, otherwise we can merge when resolved and model is still verified working.

BASELINE

Embedding shape: (1, 384)
Embedding vector: [[ 0.47021738 -0.08181865 -0.97021395 0.10116822 -0.16487181 -0.4128406
-0.28690535 -0.6374485 ]]

llama.cpp

llama.cpp command: /Users/ryanmangeno/Projects/gits/llama-fix/llama.cpp/build/bin/llama-embedding -m /Users/ryanmangeno/models/granite-embd.gguf -p "hello world" --temp 0 --embd-normalize -1
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using tokenizers before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 0.467847 -0.10217 -0.970384 0.095426 -0.17978 -0.402593 -0.299406
-0.641923]

COSINE SIMILARITY: [0.99995315]

BASELINE

Embedding shape: (1, 384)
Embedding vector: [[ 1.265958 0.05745688 -0.1299568 1.3856939 0.06200486 -1.2863939
-0.29490948 1.1680877 ]]

llama.cpp

llama.cpp command: /Users/ryanmangeno/Projects/gits/llama-fix/llama.cpp/build/bin/llama-embedding -m /Users/ryanmangeno/models/granite-embd.gguf -p "tell me a story about a developer and their dog" --temp 0 --embd-normalize -1
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 1.263891 0.048926 -0.122303 1.387904 0.056914 -1.268719 -0.309265
1.171175]

COSINE SIMILARITY: [0.9998751]

BASELINE

Embedding shape: (1, 384)
Embedding vector: [[ 0.46219334 -0.2236906 -1.063257 0.92421275 0.8207395 -0.04330844
-0.43593448 -0.04913698]]

llama.cpp

llama.cpp command: /Users/ryanmangeno/Projects/gits/llama-fix/llama.cpp/build/bin/llama-embedding -m /Users/ryanmangeno/models/granite-embd.gguf -p "123sfg this is a r@nd0m t35t" --temp 0 --embd-normalize -1
llama.cpp Embedding shape: (384,)
llama.cpp Embedding vector: [ 0.515068 -0.271638 -1.040524 0.929816 0.825909 -0.105021 -0.429643
-0.036336]

COSINE SIMILARITY: [0.99968832]

Ran the script Gabe had again after all the changes, and it seems to be holding up.
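For reference, the metric reported above can be reconstructed as a plain cosine similarity between the baseline (HF) embedding and the llama.cpp embedding. The actual comparison script is not reproduced in this thread, so this is just a hedged pure-Python stand-in:

```python
import math

def cosine_similarity(a, b):
    # dot(a, b) / (|a| * |b|); inputs are flat embedding vectors
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Vectors pointing the same direction score 1.0; orthogonal vectors score 0.0.
print(cosine_similarity([1.0, 2.0], [2.0, 4.0]))  # 1.0
```

Scores of ~0.9997 and above, as in the runs here, indicate the two implementations produce nearly identical embedding directions.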

@CISC (Member) left a comment:

Sorry, some final nits.

@CISC (Member) commented Dec 22, 2025

Since this is still missing head.norm support, it will be merged as Granite Embedding support; follow-up PRs for full ModernBERT support will be appreciated.

@CISC CISC merged commit dfc959b into ggml-org:master Dec 22, 2025
70 of 72 checks passed
@ryan-mangeno (Contributor, Author) commented Dec 22, 2025

> Since this is still missing head.norm support it will be merged as Granite Embedding support, follow up PRs for full Modern Bert support will be appreciated.

Ok, great! I will start working on full ModernBERT support. Thanks again for all the help and review!!! @CISC @gabe-l-hart

Nexesenex added a commit to Nexesenex/croco.cpp that referenced this pull request Dec 23, 2025
@ryan-mangeno ryan-mangeno deleted the modern-bert-support branch December 23, 2025 21:11
@ryan-mangeno ryan-mangeno restored the modern-bert-support branch December 23, 2025 21:18
Anico2 added a commit to Anico2/llama.cpp that referenced this pull request Jan 15, 2026
ModernBERT but without `head.norm` so will currently fail to convert and run any other ModernBERT models, PRs with `head.norm` support welcome!

* constants and tensor mappings for modern bert support, model not supported yet but working on getting conversion to work for encoder only

* conversion now working, hf -> gguf

* working on support, now working on building graph

* some cleanup

* cleanup

* continuing

* correct tensor shape for qkv

* fixed tensor mappings and working on building graph

* tensor debugging now works -> (llama-eval-callback), instead of simulated gate split with views, GEGLU is now used which does exactly this

* cleanup

* cleanup

* cleanup

* more cleanup

* ubatch issues, the assert for checking equal seqs in llama-graph.cpp when building attention  keeps failing, setting ubatch size to 1 when running llama-embedding with --ubatch-size 1 makes it work, but needs to be looked into more

* added cls token per previous modern bert attempt, still working on checking out the rest

* fixed pre tokenizer and still working through previous pr

* working through previous attempt, implemented more accurate conversion per previous attempt, added local sliding window attention that alternates every third layer

* fixed pre tokenizer

* working on swa with local and global alternating attention

* some cleanup and now fails on build attn

* starting to work, and some cleanup, currently failing on last layer construction in graph build

* alternating rope implemented and modern bert graph build succeeds

* fixed assert for equal ubatch seq

* cleanup

* added mask check in vocab

* fixed alternating rope, the hparams.rope_freq_base_train and hparams.rope_freq_base_train_swa were the same and i set them to correct values

* reuse variable

* removed repeat

* standard swa method can be used instead of a new enum being LLAMA_SWA_TYPE_LOCAL

* correct swa layer indexing, is supposed to be 0, 3, 6 ... instead of 1, 4, 7 ...

* more modular hparam setting

* replaced attn out norm with ffn_norm and cosine similarity between hf embds and llama.cpp embds went way up, from 0.05 to 0.24, replaced the cacheless kv with swa todo per the previous conversion

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update convert_hf_to_gguf_update.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-vocab.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-graph.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-arch.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* removed redundant hparam set

* enums for model sizes

* conversion for modern-bert model supported rather than just granite-small

* Update src/llama-model.cpp

Co-authored-by: Gabe Goodhart <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Gabe Goodhart <[email protected]>

* fixed ordering of enum for freq_base_swa

* fixed where I added residual, now gives much much better embeddings~

* readded cacheless logic

* removing whitespace

* conversion now working for swa pattern - dense every n layers

* modern bert put into separate src file

* removing whitespace

* fixed whitespace and newline errors in editorconfig job

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* better naming convention, n_swa_pattern -> swa_period

* reusing sliding_window_pattern key rather than making new dense_every_n_layers key, and adding writing and reading support

* fixing pyright type-check fail

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-hparams.h

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model-saver.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/gguf_writer.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/models/modern-bert.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* added descriptions in llama-model

* fixed tensor mappings for conversion

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* mapping name for size

* nits

* unused

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
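One of the commits above swaps a simulated gate split (two tensor views) for GEGLU. As a rough illustration of what GEGLU computes, here is a pure-Python sketch; the assumption that the up-projection emits 2*d values with the first half acting as the GELU gate is a convention choice for this example, not a statement about the exact tensor layout in llama.cpp.

```python
import math

def gelu(x: float) -> float:
    # exact (erf-based) GELU activation
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def geglu(up: list) -> list:
    # Split the 2*d up-projection output in half: one half is passed
    # through GELU and gates the other half elementwise.
    d = len(up) // 2
    gate, value = up[:d], up[d:]
    return [gelu(g) * v for g, v in zip(gate, value)]

print(geglu([0.0, 1.0, 3.0, 2.0]))  # gelu(0)*3 = 0.0, gelu(1)*2 ≈ 1.683
```

Doing this as a single fused GEGLU op avoids materializing the two views that the earlier "simulated gate split" approach needed.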
blime4 referenced this pull request in blime4/llama.cpp Feb 5, 2026

Labels: model (Model specific), python (python script changes)

Projects: None yet

5 participants